{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaa5b1aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import f1_score, confusion_matrix, accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from model_ALT import MyCatBoostModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2eaeba5a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'CONTROL', 'BIPOLAR', 'SCHIZO', 'DEPRESSION'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_631673/3694097567.py:18: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
      "  my_dataset = pd.concat(\n"
     ]
    }
   ],
   "source": [
    "# Column layout of the assembled dataset: the label first, then the features.\n",
    "cols = ['diagnosis', 'bpm', 'sdnn', 'rmssd', 'hr_mad', 'hf', 'rsa']\n",
    "\n",
    "data_path = os.path.join(os.getcwd(), 'data_ALT')\n",
    "\n",
    "# Per-file frames are collected and concatenated once after the loop.\n",
    "# Concatenating onto an initially empty DataFrame on every iteration copies\n",
    "# all rows each time and triggers pandas' FutureWarning about empty entries.\n",
    "frames = []\n",
    "my_disorders = set()\n",
    "\n",
    "# os.listdir returns names in arbitrary order; sorting keeps the row order\n",
    "# (and therefore any seeded split of my_dataset) stable between machines.\n",
    "for data_file in sorted(os.listdir(data_path)):\n",
    "    # The diagnosis label is the part of the file name before the first '__'.\n",
    "    diagnosis = data_file.split('__')[0]\n",
    "    my_disorders.add(diagnosis)\n",
    "\n",
    "    frames.append(\n",
    "        pd.read_csv(os.path.join(data_path, data_file))\n",
    "        .dropna()                      # drop rows with any missing value\n",
    "        .reindex(columns=cols)         # keep only the expected columns, in order\n",
    "        .assign(diagnosis=diagnosis)   # label every row with the file's diagnosis\n",
    "    )\n",
    "\n",
    "# With no input files, fall back to an empty frame that still has the\n",
    "# expected columns (pd.concat raises on an empty list).\n",
    "my_dataset = (\n",
    "    pd.concat(frames, ignore_index=True)\n",
    "    if frames\n",
    "    else pd.DataFrame(columns=cols)\n",
    ")\n",
    "\n",
    "print(my_disorders)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3e45aea8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(162, 6)\n",
      "(162,)\n",
      "['BIPOLAR' 'CONTROL' 'DEPRESSION' 'SCHIZO']\n"
     ]
    }
   ],
   "source": [
    "# Hold out 20% of all rows as the test set, then take 10% of the remainder\n",
    "# as the validation set. Both splits are stratified by the diagnosis label.\n",
    "split_options = dict(shuffle=True, random_state=190525)\n",
    "\n",
    "samples = my_dataset.to_numpy()\n",
    "labels = my_dataset[['diagnosis']].to_numpy()\n",
    "\n",
    "train, test = train_test_split(\n",
    "    samples, test_size=0.2, stratify=labels, **split_options\n",
    ")\n",
    "train, val = train_test_split(\n",
    "    train, test_size=0.1, stratify=train[:, 0], **split_options\n",
    ")\n",
    "\n",
    "# Column 0 holds the diagnosis label; the remaining columns are the features.\n",
    "train_y, train_X = train[:, 0], train[:, 1:]\n",
    "val_y, val_X = val[:, 0], val[:, 1:]\n",
    "test_y, test_X = test[:, 0], test[:, 1:]\n",
    "\n",
    "# Sanity check: feature matrix shape, label vector shape, classes present.\n",
    "for summary in (train_X.shape, train_y.shape, np.unique(train_y)):\n",
    "    print(summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c0f236b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Начинаю 1vA-тренировку для CONTROL...\n",
      "0:\tlearn: 0.6399217\ttest: 0.6430983\tbest: 0.6430983 (0)\ttotal: 92.1ms\tremaining: 3m 4s\n",
      "100:\tlearn: 0.0240822\ttest: 0.1946562\tbest: 0.1796456 (67)\ttotal: 5.73s\tremaining: 1m 47s\n",
      "200:\tlearn: 0.0097568\ttest: 0.2292670\tbest: 0.1796456 (67)\ttotal: 13.1s\tremaining: 1m 56s\n",
      "300:\tlearn: 0.0058616\ttest: 0.2538745\tbest: 0.1796456 (67)\ttotal: 21s\tremaining: 1m 58s\n",
      "400:\tlearn: 0.0042716\ttest: 0.2694959\tbest: 0.1796456 (67)\ttotal: 29.3s\tremaining: 1m 56s\n",
      "500:\tlearn: 0.0034172\ttest: 0.2815836\tbest: 0.1796456 (67)\ttotal: 37.3s\tremaining: 1m 51s\n",
      "bestTest = 0.1796455516\n",
      "bestIteration = 67\n",
      "Shrink model to first 68 iterations.\n",
      "Тренировка весов для CONTROL завершена!\n",
      "\n",
      "Начинаю 1vA-тренировку для BIPOLAR...\n",
      "0:\tlearn: 0.6564961\ttest: 0.6598405\tbest: 0.6598405 (0)\ttotal: 71.4ms\tremaining: 2m 22s\n",
      "100:\tlearn: 0.1501597\ttest: 0.3202584\tbest: 0.3162170 (89)\ttotal: 7.3s\tremaining: 2m 17s\n",
      "200:\tlearn: 0.0798711\ttest: 0.3296942\tbest: 0.3162170 (89)\ttotal: 15.2s\tremaining: 2m 16s\n",
      "300:\tlearn: 0.0503633\ttest: 0.3486733\tbest: 0.3162170 (89)\ttotal: 23.4s\tremaining: 2m 11s\n",
      "400:\tlearn: 0.0358594\ttest: 0.3646566\tbest: 0.3162170 (89)\ttotal: 31.4s\tremaining: 2m 5s\n",
      "500:\tlearn: 0.0266094\ttest: 0.3792708\tbest: 0.3162170 (89)\ttotal: 39.6s\tremaining: 1m 58s\n",
      "bestTest = 0.3162170251\n",
      "bestIteration = 89\n",
      "Shrink model to first 90 iterations.\n",
      "Тренировка весов для BIPOLAR завершена!\n",
      "\n",
      "Начинаю 1vA-тренировку для SCHIZO...\n",
      "0:\tlearn: 0.6694040\ttest: 0.6783148\tbest: 0.6783148 (0)\ttotal: 72.7ms\tremaining: 2m 25s\n",
      "100:\tlearn: 0.1782067\ttest: 0.4110960\tbest: 0.4082890 (97)\ttotal: 8.53s\tremaining: 2m 40s\n",
      "200:\tlearn: 0.1045418\ttest: 0.4126275\tbest: 0.4053637 (156)\ttotal: 16.9s\tremaining: 2m 31s\n",
      "300:\tlearn: 0.0670763\ttest: 0.4166782\tbest: 0.4053637 (156)\ttotal: 25.5s\tremaining: 2m 23s\n",
      "400:\tlearn: 0.0486641\ttest: 0.4294562\tbest: 0.4053637 (156)\ttotal: 33.7s\tremaining: 2m 14s\n",
      "500:\tlearn: 0.0394716\ttest: 0.4379925\tbest: 0.4053637 (156)\ttotal: 42.6s\tremaining: 2m 7s\n",
      "600:\tlearn: 0.0315195\ttest: 0.4428804\tbest: 0.4053637 (156)\ttotal: 50.7s\tremaining: 1m 57s\n",
      "bestTest = 0.4053637187\n",
      "bestIteration = 156\n",
      "Shrink model to first 157 iterations.\n",
      "Тренировка весов для SCHIZO завершена!\n",
      "\n",
      "Начинаю 1vA-тренировку для DEPRESSION...\n",
      "0:\tlearn: 0.6465629\ttest: 0.6486937\tbest: 0.6486937 (0)\ttotal: 33.3ms\tremaining: 1m 6s\n",
      "100:\tlearn: 0.0644317\ttest: 0.2156927\tbest: 0.2089746 (54)\ttotal: 4.94s\tremaining: 1m 32s\n",
      "200:\tlearn: 0.0251333\ttest: 0.2277975\tbest: 0.2089746 (54)\ttotal: 12.7s\tremaining: 1m 53s\n",
      "300:\tlearn: 0.0120746\ttest: 0.2418612\tbest: 0.2089746 (54)\ttotal: 21.1s\tremaining: 1m 58s\n",
      "400:\tlearn: 0.0070298\ttest: 0.2553018\tbest: 0.2089746 (54)\ttotal: 29.1s\tremaining: 1m 56s\n",
      "500:\tlearn: 0.0047995\ttest: 0.2666015\tbest: 0.2089746 (54)\ttotal: 37.7s\tremaining: 1m 52s\n",
      "bestTest = 0.2089745601\n",
      "bestIteration = 54\n",
      "Shrink model to first 55 iterations.\n",
      "Тренировка весов для DEPRESSION завершена!\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): 2000 is presumably the per-class boosting iteration budget;\n",
    "# confirm against MyCatBoostModel in model_ALT.\n",
    "n_iterations = 2000\n",
    "disorder_names = list(my_disorders)\n",
    "\n",
    "m = MyCatBoostModel(n_iterations, disorder_names)\n",
    "\n",
    "# The training split is passed together with the validation split.\n",
    "m.forward(train_X, train_y, val_X, val_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e4f0381c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== CONTROL ===\n",
      "bpm -- 7.391761241653948\n",
      "sdnn -- 8.003032308301362\n",
      "rmssd -- 3.8495148259904406\n",
      "hr_mad -- 77.12499099792296\n",
      "hf -- 2.326219509654539\n",
      "rsa -- 1.304481116476733\n",
      "=== BIPOLAR ===\n",
      "bpm -- 13.527586103523326\n",
      "sdnn -- 27.657339767737266\n",
      "rmssd -- 13.963220415418176\n",
      "hr_mad -- 24.373204120135895\n",
      "hf -- 9.568722890395362\n",
      "rsa -- 10.909926702789956\n",
      "=== SCHIZO ===\n",
      "bpm -- 16.211258567702057\n",
      "sdnn -- 12.566827800635192\n",
      "rmssd -- 21.7459941203426\n",
      "hr_mad -- 22.85964240542689\n",
      "hf -- 15.530747328594487\n",
      "rsa -- 11.085529777298772\n",
      "=== DEPRESSION ===\n",
      "bpm -- 39.23121517528068\n",
      "sdnn -- 21.487919707011386\n",
      "rmssd -- 15.019654364720417\n",
      "hr_mad -- 10.12291718630771\n",
      "hf -- 9.810855448594772\n",
      "rsa -- 4.327438118085046\n"
     ]
    }
   ],
   "source": [
    "importances = m.get_importances()\n",
    "\n",
    "# Feature names are every column after the leading 'diagnosis' label.\n",
    "feature_names = cols[1:]\n",
    "\n",
    "# NOTE(review): assumes importances[k] belongs to the k-th disorder in the\n",
    "# set's iteration order; verify against MyCatBoostModel.get_importances.\n",
    "for disorder_idx, disorder in enumerate(my_disorders):\n",
    "    print(f'=== {disorder} ===')\n",
    "    for feature_idx, importance in enumerate(importances[disorder_idx]):\n",
    "        print(f'{feature_names[feature_idx]} -- {importance}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f9b3b9e7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/wsl/ИИ ПРОЕКТ/ХреНоваяПапка/model.py:118: RuntimeWarning: invalid value encountered in scalar divide\n",
      "  precision = tp/(tp+fp)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(['CONTROL accuracy is 0.9782608695652174',\n",
       "  'BIPOLAR accuracy is 0.782608695652174',\n",
       "  'SCHIZO accuracy is 0.782608695652174',\n",
       "  'DEPRESSION accuracy is 0.9782608695652174'],\n",
       " ['CONTROL f1-score is 0.972972972972973',\n",
       "  'BIPOLAR f1-score is 0.375',\n",
       "  'SCHIZO f1-score is 0.7222222222222222',\n",
       "  'DEPRESSION f1-score is nan'],\n",
       " ['CONTROL precision is 0.9473684210526315',\n",
       "  'BIPOLAR precision is 0.42857142857142855',\n",
       "  'SCHIZO precision is 0.7222222222222222',\n",
       "  'DEPRESSION precision is nan'],\n",
       " ['CONTROL recall is 1.0',\n",
       "  'BIPOLAR recall is 0.3333333333333333',\n",
       "  'SCHIZO recall is 0.7222222222222222',\n",
       "  'DEPRESSION recall is 0.0'],\n",
       " ['CONTROL specificity is 0.9642857142857143',\n",
       "  'BIPOLAR specificity is 0.8918918918918919',\n",
       "  'SCHIZO specificity is 0.8214285714285714',\n",
       "  'DEPRESSION specificity is 1.0'])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Switch the model to its 'test' mode before scoring the held-out split.\n",
    "evaluation_mode = 'test'\n",
    "m.set_mode(evaluation_mode)\n",
    "\n",
    "# Keep the returned metrics in `results` and show them as the cell output.\n",
    "results = m.forward(test_X, test_y)\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b8a7e546",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Веса для CONTROL сохранены в \"w__CONTROL.cbm\"\n",
      "Веса для BIPOLAR сохранены в \"w__BIPOLAR.cbm\"\n",
      "Веса для SCHIZO сохранены в \"w__SCHIZO.cbm\"\n",
      "Веса для DEPRESSION сохранены в \"w__DEPRESSION.cbm\"\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the argument appears to be a file-name prefix for the saved\n",
    "# per-class weight files; confirm against MyCatBoostModel.save_weights.\n",
    "weights_prefix = \"w\"\n",
    "m.save_weights(weights_prefix)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
